Add seq parallelism for attention and MoE MLP #1328
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for the change and improvement! Overall, LGTM! One thing I'd like to see is the performance impact on training. When it's ready, could you help run a benchmark on 8x7b (or another model size) with FSDP + EP sharding in dropping mode, with and without this change? Capturing profiles would be great! Thank you!
LGTM!
Thanks Qinwen! It would be great if you could help run a few steps of training tests with profiles. Previously we hit issues when optimizing inference, where some changes caused a degradation of training performance.
Description
With sequence parallelism (SP) + expert parallelism (EP), the customer's MoE 2k-sequence inference improved by 20% (a minimal sharding sketch follows below).
FIXES: b/374773995
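
A minimal sketch of the sequence-parallelism idea, for reviewers unfamiliar with it. This is an assumed illustration, not this PR's actual code: activations are sharded along the sequence axis so each chip processes seq/N tokens through attention and the MoE MLP. It assumes at least 4 local devices, mirroring ici_context_parallelism=4 in the benchmark command under Tests.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-axis mesh over 4 chips, mirroring ici_context_parallelism=4.
mesh = Mesh(np.array(jax.devices()[:4]), axis_names=("context",))

batch, seq, emb = 1, 2048, 4096  # matches the 2k-sequence benchmark below
x = jnp.zeros((batch, seq, emb))

# Partition the sequence dimension across the "context" mesh axis; each
# device now holds a (1, 512, 4096) shard of the activations.
x_sp = jax.device_put(x, NamedSharding(mesh, P(None, "context", None)))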
Tests
Tested on v6e/v5p:
SEQ=2048
python MaxText/inference_microbenchmark.py MaxText/configs/inference.yml \
  max_prefill_predict_length=$SEQ max_target_length=6144 \
  model_name=mixtral-8x7b \
  ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 \
  ici_expert_parallelism=1 ici_context_parallelism=4 ici_tensor_parallelism=1 \
  scan_layers=false per_device_batch_size=1 attention=dot_product \
  megablox=False quantization=int8 checkpoint_is_quantized=True \
  quantize_kvcache=True capacity_factor=1 \
  tokenizer_path=assets/tokenizer.mistral-v3 \
  compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3 \
  enable_jax_profiler=True inference_microbenchmark_prefill_lengths="$SEQ" \
  base_output_directory=$OUT_DIR run_name=$RUN_NAME \
  profiler=xplane model_call_mode=inference \
  inference_microbenchmark_stages=prefill
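
The ici_* flags above set the per-slice mesh axis sizes; only the context axis is greater than 1, so all 4 chips of the slice are devoted to the sequence (context) parallelism this PR adds. A hypothetical illustration of how such flags could compose into a device mesh (axis names and ordering are assumed for illustration, not copied from MaxText's source):

import numpy as np
import jax
from jax.sharding import Mesh

# Assumed axis names/order; values mirror the ici_* flags in the command above.
axis_sizes = {"fsdp": 1, "autoregressive": 1, "expert": 1, "context": 4, "tensor": 1}

# Assumes a 4-chip slice (e.g., one v6e-4 host) so the sizes multiply to 4.
devices = np.array(jax.devices()[:4]).reshape(*axis_sizes.values())
mesh = Mesh(devices, axis_names=tuple(axis_sizes))
print(mesh.shape)  # only the "context" axis has size > 1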
Checklist
Before submitting this PR, please make sure (put X in square brackets):